{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Batched Single Step Environment in Camel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Single Step environments are the most widespread type of environment when doing RL with an LLM as policy.\n",
    "\n",
    "It's called *single step* environment, because the agent only does one step. It gets a question sampled from the dataset (the initial state / observation) and then answers. The answer is then scored according to the reward function. Recently, rules-based reward functions, i.e. functions without any learnable parameters, have been successfully used to do RL with LLMs as as policy.\n",
    "\n",
    "Since many RL algorithms (such as GRPO) need multiple rollouts at each step, batching is important to guarantee concurrency / parallelism. This notebook will show how to use batched environments.\n",
    "\n",
    "First, we have to load a dataset from which we will sample questions. The dataset can be either a `StaticDataset`, which is finite and the length is known at runtime, or it can be a `BaseGenerator`, which is an infinite supply of question - answer pairs, synthetically generated in some way (depending on the implementation).\n",
    "\n",
    "For the sake of simplicity, we will start by loading the MATH dataset, remove unnecessary columns and rename the remaining ones, such that we can easily turn it into a `StaticDataset`, which `SingleStepEnv` can deal with. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, install the CAMEL package with all its dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install camel-ai[all]==0.2.46"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/student/Code/camel/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['question', 'final_answer']\n",
      "Example datapoint: question='Let \\\\[f(x) = \\\\left\\\\{\\n\\\\begin{array}{cl} ax+3, &\\\\text{ if }x>2, \\\\\\\\\\nx-5 &\\\\text{ if } -2 \\\\le x \\\\le 2, \\\\\\\\\\n2x-b &\\\\text{ if } x <-2.\\n\\\\end{array}\\n\\\\right.\\\\]Find $a+b$ if the piecewise function is continuous (which means that its graph can be drawn without lifting your pencil from the paper).' final_answer='For the piecewise function to be continuous, the cases must \"meet\" at $2$ and $-2$. For example, $ax+3$ and $x-5$ must be equal when $x=2$. This implies $a(2)+3=2-5$, which we solve to get $2a=-6 \\\\Rightarrow a=-3$. Similarly, $x-5$ and $2x-b$ must be equal when $x=-2$. Substituting, we get $-2-5=2(-2)-b$, which implies $b=3$. So $a+b=-3+3=\\\\boxed{0}$.' rationale=None metadata=None\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "from camel.datasets import StaticDataset\n",
    "from camel.logger import get_logger\n",
    "\n",
    "logger = get_logger(__name__)\n",
    "\n",
    "\n",
    "dataset = load_dataset(\"EleutherAI/hendrycks_math\", \"algebra\")\n",
    "\n",
    "# Preprocess\n",
    "dataset[\"train\"] = dataset[\"train\"].rename_column(\"problem\", \"question\")\n",
    "dataset[\"train\"] = dataset[\"train\"].rename_column(\"solution\", \"final_answer\")\n",
    "dataset[\"train\"] = dataset[\"train\"].remove_columns([\"type\", \"level\"])\n",
    "\n",
    "# This should now print \"['question', 'final_answer']\"\n",
    "print(dataset[\"train\"].column_names)\n",
    "seed_dataset = StaticDataset(dataset['train'])\n",
    "\n",
    "print(\"Example datapoint:\", seed_dataset[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we will have to define an *extractor*. An extractor takes the LLM response and extracts the verifiable part out of it. Extractors can be initialized with different strategies which modifies the extraction behavior.\n",
    "\n",
    "In the case of the MATH dataset, the final answer is wrapped inside a `\\boxed{...}`, hence we should use the pre-built `BoxedStrategy`.\n",
    "\n",
    "Sadly, MATH answers are rather complicated and a more general Math verifier to compare, for example, equations has not yet been implemented. Hence, we shall prune the dataset to only contain those rows where the content of `\\boxed{...}` is an int. For the sake of simplicity, we shall also prune the ground truthes to the direct answer (such that they are python expressions). That way, we can do simple verification using the vanilla PythonVerifier!"
   ]
  },
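  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the extraction and pruning steps concrete, here is a minimal, library-free sketch of pulling the content out of `\\boxed{...}` and checking whether it is an integer. The helper names `extract_last_boxed` and `is_integer_literal` are hypothetical, and this is not a description of how `BoxedStrategy` is implemented; the notebook itself keeps using the CAMEL extractor.\n",
    "\n",
    "```python\n",
    "from typing import Optional\n",
    "\n",
    "MARKER = r'\\boxed{'\n",
    "\n",
    "\n",
    "def extract_last_boxed(text: str) -> Optional[str]:\n",
    "    # Hypothetical helper: content of the last \\boxed{...}, or None.\n",
    "    start = text.rfind(MARKER)\n",
    "    if start == -1:\n",
    "        return None\n",
    "    content_start = start + len(MARKER)\n",
    "    depth = 1  # the opening brace of the marker is already consumed\n",
    "    for i in range(content_start, len(text)):\n",
    "        if text[i] == '{':\n",
    "            depth += 1\n",
    "        elif text[i] == '}':\n",
    "            depth -= 1\n",
    "            if depth == 0:\n",
    "                return text[content_start:i]\n",
    "    return None  # braces never balanced\n",
    "\n",
    "\n",
    "def is_integer_literal(value: str) -> bool:\n",
    "    # Same test as the filtering cell: digits with an optional leading minus.\n",
    "    return value.isdigit() or (\n",
    "        value.startswith('-') and value[1:].isdigit()\n",
    "    )\n",
    "\n",
    "\n",
    "answer = extract_last_boxed(r'So $a+b=-3+3=\\boxed{0}$.')\n",
    "nested = extract_last_boxed(r'Thus \\boxed{\\frac{1}{2}}.')\n",
    "print(answer, is_integer_literal(answer))\n",
    "print(nested, is_integer_literal(nested))\n",
    "```\n",
    "\n",
    "Counting braces rather than using a simple regular expression keeps nested LaTeX such as `\\frac{1}{2}` intact, which is why the second example is extracted whole and then rejected by the integer check."
   ]
  },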
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of datapoints with integer answers: 1228\n"
     ]
    }
   ],
   "source": [
    "from camel.extractors import BaseExtractor, BoxedStrategy\n",
    "\n",
    "# Initialize extractor\n",
    "extractor = BaseExtractor([[BoxedStrategy()]])\n",
    "await extractor.setup()\n",
    "\n",
    "valid_datapoints = []\n",
    "\n",
    "# Iterate through dataset, checking for datapoints with integer answers\n",
    "for datapoint in seed_dataset:\n",
    "    extracted_value = await extractor.extract(response=datapoint.final_answer)\n",
    "\n",
    "    if not extracted_value:\n",
    "        continue\n",
    "    if extracted_value.isdigit() or (\n",
    "        extracted_value.startswith('-') and extracted_value[1:].isdigit()\n",
    "    ):\n",
    "        valid_datapoints.append(\n",
    "            {\n",
    "                \"question\": datapoint.question,\n",
    "                \"final_answer\": extracted_value,\n",
    "            }\n",
    "        )\n",
    "\n",
    "# We should now have `1228` valid data points.\n",
    "print(f\"Number of datapoints with integer answers: {len(valid_datapoints)}\")\n",
    "\n",
    "filtered_dataset = StaticDataset(valid_datapoints, seed=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a Python verifier to later compare answers. Since we are reusing the same extractor from before, the PythonVerifier will expect solutions to be wrapped in `\\boxed{...}`, too."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from camel.verifiers import PythonVerifier\n",
    "\n",
    "verifier = PythonVerifier(extractor=extractor)\n",
    "await verifier.setup(uv=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's now initialize the single step environment with our filtered dataset and our verifier. The verifier will later help with the correctness reward \n",
    "\n",
    "We can then call `env.reset(batch_size=4)` to draw from the initial state distribution (the dataset) and return `batch_size` many observations, which can then be fed into the agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2025-03-24 10:50:32,920 - camel.camel.environments.single_step - WARNING - reset() called on un-setup environment. Setting up...\n",
      "question='The values of a function $f(x)$ are given below:\\n\\n\\\\begin{tabular}{|c||c|c|c|c|c|} \\\\hline $x$ & 3 & 4 & 5 & 6 & 7 \\\\\\\\ \\\\hline $f(x)$ & 10 & 17 & 26 & 37 & 50 \\\\\\\\ \\\\hline \\\\end{tabular}Evaluate $f^{-1}\\\\left(f^{-1}(50)\\\\times f^{-1}(10)+f^{-1}(26)\\\\right)$.' context={} metadata={}\n",
      "question='Find the product of the solutions of: $|y|=2(|y|-1)$.' context={} metadata={}\n",
      "question='Let $A,B$ be the points on the coordinate plane with coordinates $(t-4,-1)$ and $(-2,t+3)$, respectively. The square of the distance between the midpoint of $\\\\overline{AB}$ and an endpoint of $\\\\overline{AB}$ is equal to $t^2/2$. What is the value of $t$?' context={} metadata={}\n",
      "question='An amoeba is placed in a puddle one day, and on that same day it splits into two amoebas.  The next day, each new amoeba splits into two new amoebas, and so on, so that each day every living amoeba splits into two new amoebas. After one week, how many amoebas are in the puddle? (Assume the puddle has no amoebas before the first one is placed in the puddle.)' context={} metadata={}\n"
     ]
    }
   ],
   "source": [
    "from camel.environments import Action, SingleStepEnv\n",
    "\n",
    "env = SingleStepEnv(filtered_dataset, verifier)\n",
    "\n",
    "obs = await env.reset(batch_size=4, seed=42)\n",
    "\n",
    "for ob in obs:\n",
    "    print(ob)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The agent would then process these observation and select an action for each observation, which it would feed into the `step` function. An action in this case would simply be the answer to the question, wrapped in `\\boxed{}` (since we initialized our verifier with an extractor that extracts from `\\boxed{...}`).\n",
    "\n",
    "Since we are dealing with batches here, it's assign an index to each question, such that it matches up with the observation that the observation came from. This way, we support microbatching and out-of-order execution!\n",
    "\n",
    "Let's suppose we deal with $2$ microbatches in reverse order:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[StepResult(observation=Observation(question='Episode ended. This is just a placeholder.', context={}, metadata=None), reward=10.0, rewards_dict={'correctness': 10.0}, done=True, info={'proposed_solution': '\\\\boxed{-5}', 'verification_result': VerificationResult(status=<VerificationOutcome.SUCCESS: 'success'>, result='-5', duration=0.00016641616821289062, timestamp=datetime.datetime(2025, 3, 24, 10, 50, 32, 926995), metadata={'attempt': 1}, error_message=None), 'state': DataPoint(question='Let $A,B$ be the points on the coordinate plane with coordinates $(t-4,-1)$ and $(-2,t+3)$, respectively. The square of the distance between the midpoint of $\\\\overline{AB}$ and an endpoint of $\\\\overline{AB}$ is equal to $t^2/2$. What is the value of $t$?', final_answer='-5', rationale=None, metadata=None)}),\n",
       " StepResult(observation=Observation(question='Episode ended. This is just a placeholder.', context={}, metadata=None), reward=10.0, rewards_dict={'correctness': 10.0}, done=True, info={'proposed_solution': '\\\\boxed{128}', 'verification_result': VerificationResult(status=<VerificationOutcome.SUCCESS: 'success'>, result='128', duration=0.00016999244689941406, timestamp=datetime.datetime(2025, 3, 24, 10, 50, 32, 927014), metadata={'attempt': 1}, error_message=None), 'state': DataPoint(question='An amoeba is placed in a puddle one day, and on that same day it splits into two amoebas.  The next day, each new amoeba splits into two new amoebas, and so on, so that each day every living amoeba splits into two new amoebas. After one week, how many amoebas are in the puddle? (Assume the puddle has no amoebas before the first one is placed in the puddle.)', final_answer='128', rationale=None, metadata=None)})]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "microbatch1 = [Action(index=2, llm_response=\"\\\\boxed{-5}\"), Action(index=3, llm_response=\"\\\\boxed{128}\")]\n",
    "\n",
    "await env.step(microbatch1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have already received rewards for actions 2 and 3 of our environment. Let's next finish this environment. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is the batch done?: False\n"
     ]
    }
   ],
   "source": [
    "print(f\"Is the batch done?: {env._batch_done()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[StepResult(observation=Observation(question='Episode ended. This is just a placeholder.', context={}, metadata=None), reward=10.0, rewards_dict={'correctness': 10.0}, done=True, info={'proposed_solution': '\\\\boxed{5}', 'verification_result': VerificationResult(status=<VerificationOutcome.SUCCESS: 'success'>, result='5', duration=0.00030922889709472656, timestamp=datetime.datetime(2025, 3, 24, 10, 50, 32, 942092), metadata={'attempt': 1}, error_message=None), 'state': DataPoint(question='The values of a function $f(x)$ are given below:\\n\\n\\\\begin{tabular}{|c||c|c|c|c|c|} \\\\hline $x$ & 3 & 4 & 5 & 6 & 7 \\\\\\\\ \\\\hline $f(x)$ & 10 & 17 & 26 & 37 & 50 \\\\\\\\ \\\\hline \\\\end{tabular}Evaluate $f^{-1}\\\\left(f^{-1}(50)\\\\times f^{-1}(10)+f^{-1}(26)\\\\right)$.', final_answer='5', rationale=None, metadata=None)}),\n",
       " StepResult(observation=Observation(question='Episode ended. This is just a placeholder.', context={}, metadata=None), reward=10.0, rewards_dict={'correctness': 10.0}, done=True, info={'proposed_solution': '\\\\boxed{-4}', 'verification_result': VerificationResult(status=<VerificationOutcome.SUCCESS: 'success'>, result='-4', duration=0.0003139972686767578, timestamp=datetime.datetime(2025, 3, 24, 10, 50, 32, 942131), metadata={'attempt': 1}, error_message=None), 'state': DataPoint(question='Find the product of the solutions of: $|y|=2(|y|-1)$.', final_answer='-4', rationale=None, metadata=None)})]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "microbatch2 = [Action(index=0, llm_response=\"\\\\boxed{5}\"), Action(index=1, llm_response=\"\\\\boxed{-4}\")]\n",
    "\n",
    "await env.step(microbatch2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Is the batch done?: True\n"
     ]
    }
   ],
   "source": [
    "print(f\"Is the batch done?: {env._batch_done()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, the output of the `step` function contains the next observation (which in this case is just a placeholder, since the episode is over), a reward, as well as a reward dict, showing exactly which rubric brought which reward, a `done` flag, indicating that the episode is over and some additional info.\n",
    "\n",
    "In this case, we get a reward of $10$, which is the reward for a correct final answer. This can be accessed and changed via the `self.ACCURACY_REWARD` attribute.\n",
    "\n",
    "Since we did not implement any other reward components, such as a formatting reward, the accuracy reward is our total reward.\n",
    "\n",
    "This is how to use the batched Single Step environment!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
